{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a07e0f5e",
   "metadata": {},
   "source": [
    "<div>\n",
    "<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\">  \n",
    "</div>\n",
    "\n",
    "# Distributed Bloom for Text Generation using Prompt Tuning\n",
    "\n",
    "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
    "\n",
    "We will adapt BLOOM for the task of creating a chatbot with a specific personality using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n",
    "\n",
    "To use this notebook in Colab:\n",
    "\n",
    "1. Follow this link: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)\n",
    "2. Go to **Runtime** -> **Change runtime type** and select the GPU accelerator."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f8526f",
   "metadata": {},
   "source": [
    "First, we have to prepare all dependencies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73bbc648",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q petals datasets wandb scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4ab6ca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "import wandb\n",
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import BloomTokenizerFast, get_scheduler\n",
    "\n",
    "from petals import DistributedBloomForCausalLM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1bf07b5d",
   "metadata": {},
   "source": [
    "Let's set some hyperparameters for training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f04ba4d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose a model you'd like to prompt-tune. We recommend starting with\n",
    "# the smaller 7.1B version of BLOOM (bigscience/bloom-7b1-petals) for faster prototyping.\n",
    "# Once your code is ready, you can switch to full-scale\n",
    "# 176B-parameter BLOOM (bigscience/bloom-petals) or BLOOMZ (bigscience/bloomz-petals).\n",
    "MODEL_NAME = \"bigscience/bloom-7b1-petals\"\n",
    "\n",
    "# Choose a prompt-tuning mode ('ptune' or 'deep_ptune').\n",
    "# The latter fine-tunes separate prefixes for each transformer block,\n",
    "# so prompt-tuning will take more time but yield better results.\n",
    "# See this paper for details of how it works: https://arxiv.org/pdf/2110.07602.pdf\n",
    "TUNING_MODE = 'ptune'\n",
    "\n",
    "NUM_PREFIX_TOKENS = 16\n",
    "DEVICE = 'cuda'\n",
    "BATCH_SIZE = 8\n",
    "LR = 1e-2\n",
    "WEIGHT_DECAY = 0.0\n",
    "NUM_SAMPLES = 1000\n",
    "SEED = 42\n",
    "MODEL_MAX_LENGTH = 256"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d38316bd",
   "metadata": {},
   "source": [
    "Prepare tokenizer and distributed model, connect it to servers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03c6e53e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
    "tokenizer.padding_side = 'right'\n",
    "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
    "model = DistributedBloomForCausalLM.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    pre_seq_len=NUM_PREFIX_TOKENS, \n",
    "    tuning_mode=TUNING_MODE\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "042e3786",
   "metadata": {},
   "source": [
    "Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c44d516",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"bavard/personachat_truecased\")\n",
    "\n",
    "\n",
    "def chunking(examples):\n",
    "    inputs = [\n",
    "        \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n",
    "        for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n",
    "        for candidate in candidates\n",
    "    ]\n",
    "    return {\"chunks\": inputs}\n",
    "\n",
    "\n",
    "def tokenize(examples):\n",
    "    outputs = {\n",
    "        \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n",
    "    }\n",
    "    outputs[\"labels\"] = outputs[\"input_ids\"]\n",
    "    return outputs\n",
    "\n",
    "\n",
    "tokenized_datasets = (\n",
    "    dataset\n",
    "        .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
    "        .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n",
    ")\n",
    "\n",
    "\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset.select(list(range(NUM_SAMPLES))),\n",
    "    shuffle=True,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    drop_last=True,\n",
    ")"
   ]
  },
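  {
   "cell_type": "markdown",
   "id": "5e2d9a71",
   "metadata": {},
   "source": [
    "To make the input format concrete, the sketch below runs `chunking` on a tiny hand-written batch. The dialogue lines in `toy_batch` are made up for illustration and are not taken from Personachat; the only assumption is that the `chunking` function from the previous cell is already defined. Each candidate answer is appended to the full history, so one dialogue with two candidates should produce two chunks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b3f6c20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical toy batch in the same column layout as the dataset.\n",
    "toy_batch = {\n",
    "    \"history\": [[\"Hi, how are you?\", \"Great, I just walked my dog.\"]],\n",
    "    \"candidates\": [[\"I love dogs!\", \"I work as a teacher.\"]],\n",
    "}\n",
    "\n",
    "# Print each chunk with repr() so the separator newlines stay visible.\n",
    "for chunk in chunking(toy_batch)[\"chunks\"]:\n",
    "    print(repr(chunk))"
   ]
  },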
  {
   "cell_type": "markdown",
   "id": "ef4323fd",
   "metadata": {},
   "source": [
    "Before setting up optimizers, check the model parameters that will be trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cc0ba34",
   "metadata": {},
   "outputs": [],
   "source": [
    "for n, p in model.named_parameters():\n",
    "    if p.requires_grad:\n",
    "        print(n, p.requires_grad, p.device)"
   ]
  },
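  {
   "cell_type": "markdown",
   "id": "c41d7e93",
   "metadata": {},
   "source": [
    "As an optional sanity check, the sketch below counts how many parameters are trainable on the client. It assumes that `model` from the cells above is loaded, and that the remote BLOOM blocks hosted by Petals servers are not exposed through `model.parameters()`, so the totals cover local parameters only. The variable names `num_trainable` and `num_local` are introduced here for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fa90b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count parameters held by this client (assumption: server-side blocks are not included).\n",
    "num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "num_local = sum(p.numel() for p in model.parameters())\n",
    "print(f\"trainable: {num_trainable:,} of {num_local:,} local parameters\")"
   ]
  },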
  {
   "cell_type": "markdown",
   "id": "59cffce7",
   "metadata": {},
   "source": [
    "The optimizer will only work on **prompts**, they are only trainable parameters. Let's initialize optimizer and learning rate scheduler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef9bf344",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
    ")"
   ]
  },
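  {
   "cell_type": "markdown",
   "id": "7d0b34ae",
   "metadata": {},
   "source": [
    "The number of training steps follows from the dataloader settings: with `drop_last=True`, one pass yields `NUM_SAMPLES // BATCH_SIZE` full batches. The sketch below checks this and prints the step count, assuming `train_dataloader` and the hyperparameters defined earlier in this notebook. With no warmup, the `linear` schedule decays the learning rate from `LR` towards zero over these steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a65c1f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One optimizer step per batch; incomplete final batches are dropped.\n",
    "num_training_steps = len(train_dataloader)\n",
    "assert num_training_steps == NUM_SAMPLES // BATCH_SIZE\n",
    "print(f\"training steps: {num_training_steps}\")"
   ]
  },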
  {
   "cell_type": "markdown",
   "id": "423c56d5",
   "metadata": {},
   "source": [
    "Let's initialize wandb for logging and start the training loop!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e46807",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.init(\n",
    "    project=\"bloom-personachat\",\n",
    "    config={\n",
    "        \"num_samples\": NUM_SAMPLES,\n",
    "        \"batch_size\": BATCH_SIZE,\n",
    "        \"learning_rate\": LR,\n",
    "        \"weight_decay\": WEIGHT_DECAY,\n",
    "        \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
    "        \"model_name\": MODEL_NAME,\n",
    "        \"seed\": SEED,\n",
    "    }\n",
    ")\n",
    "\n",
    "for batch in tqdm(train_dataloader):\n",
    "    batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
    "\n",
    "    model.train()\n",
    "    outputs = model(**batch)\n",
    "    loss = outputs.loss\n",
    "    loss.backward()\n",
    "\n",
    "    optimizer.step()\n",
    "    lr_scheduler.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    wandb.log({\"Train Loss\": loss})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f36cb80",
   "metadata": {},
   "source": [
    "Try to talk with the trained model! Submit an empty input to stop the execution.\n",
    "\n",
    "\n",
    "__Note__: In this example, we the whole dialogue as a prefix when generating each new replica. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new replica will be able to reuse inference caches from the previous replica."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "720181b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "TOP_K = 100\n",
    "TEMPERATURE = 0.6\n",
    "\n",
    "with model.inference_session(max_length=512) as sess:\n",
    "    while True:\n",
    "        user_phrase = input()\n",
    "        if len(user_phrase) == 0:\n",
    "            break\n",
    "        inputs = tokenizer([f\"{user_phrase}\\n-----\\n\"], return_tensors='pt')['input_ids'].to(DEVICE)\n",
    "        while True:\n",
    "            outputs = model.generate(\n",
    "                inputs,\n",
    "                temperature=TEMPERATURE,\n",
    "                do_sample=True,\n",
    "                top_k=TOP_K,\n",
    "                max_new_tokens=1,\n",
    "                session=sess,\n",
    "            )\n",
    "            bloom_answer_token = tokenizer.decode(outputs[0, -1:])\n",
    "            print(bloom_answer_token, end=\"\", flush=True)\n",
    "            if bloom_answer_token == \"\\n\":\n",
    "                break\n",
    "            inputs = None"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.9 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
